"""LP Solver for two-player zero-sum games."""

import cvxopt
import numpy as np
import utils
import pyspiel

# Objective sense of a LinearProgram: maximize or minimize.
OBJ_MAX, OBJ_MIN = 1, 2
# Constraint types, in order: "<=", ">=" and "=".
CONS_TYPE_LEQ, CONS_TYPE_GEQ, CONS_TYPE_EQ = 3, 4, 5

# Kinds of dominance that can be searched for.
DOMINANCE_STRICT, DOMINANCE_VERY_WEAK, DOMINANCE_WEAK = 1, 2, 3


class _Variable(object):
    """A variable in an LP."""

    def __init__(self, vid, lb=None, ub=None):
        """Creates a variable in a linear program.

        Args:
          vid: (integer) the variable id (should be unique for each variable)
          lb: the lower bound on the variable's value (None means no lower bound)
          ub: the upper bound on the variable's valie (None means no upper bound)
        """
        self.vid = vid
        self.lb = lb
        self.ub = ub


class _Constraint(object):
    """A constraint in an LP."""

    def __init__(self, cid, ctype):
        """Creates a constraint in a linear program.

        Args:
          cid: (integer) the constraint id (should be unique for each constraint)
          ctype: the constraint type (CONS_TYPE_{LEQ, GEQ, EQ})
        """
        self.cid = cid
        self.ctype = ctype
        self.coeffs = {}  # var label -> value
        self.rhs = None


class LinearProgram(object):
    """An object used to provide a user-friendly API for building LPs.

    Variables and constraints are identified by caller-chosen labels. Objective
    coefficients, constraint coefficients, right-hand sides and variable bounds
    are recorded incrementally; `solve` then assembles the matrices expected by
    `cvxopt.solvers.lp` and runs the solver.
    """

    def __init__(self, objective):
        """Creates an empty linear program.

        Args:
          objective: the optimization sense, either OBJ_MIN or OBJ_MAX.
        """
        assert objective == OBJ_MIN or objective == OBJ_MAX
        self._valid_constraint_types = [CONS_TYPE_EQ, CONS_TYPE_LEQ, CONS_TYPE_GEQ]
        self._objective = objective
        self._obj_coeffs = {}  # var label -> value
        self._vars = {}  # var label -> var
        self._cons = {}  # cons label -> constraint
        self._var_list = []  # variables in creation order (index == vid)
        self._leq_cons_list = []  # LEQ and GEQ constraints, in creation order
        self._eq_cons_list = []  # EQ constraints, in creation order

    def add_or_reuse_variable(self, label, lb=None, ub=None):
        """Adds a variable to this LP, or reuses one if the label exists.

        If the variable already exists, simply checks that the upper and lower
        bounds are the same as previously specified.

        Args:
          label: a label to assign to this variable
          lb: a lower-bound value for this variable (None means no lower bound)
          ub: an upper-bound value for this variable (None means no upper bound)
        """
        var = self._vars.get(label)
        if var is not None:
            # Do not re-add, but ensure it's the same
            assert var.lb == lb and var.ub == ub
            return
        # Ids are assigned in creation order; solve() uses them as column
        # indices, so they are also positions in the returned solution vector.
        var = _Variable(len(self._var_list), lb, ub)
        self._vars[label] = var
        self._var_list.append(var)

    def add_or_reuse_constraint(self, label, ctype):
        """Adds a constraint to this LP, or reuses one if the label exists.

        If the constraint is already present, simply checks it's the same type as
        previously specified.

        Args:
          label: a label to assign to this constraint
          ctype: the constraint type (in CONS_TYPE_{LEQ,GEQ,EQ})
        """
        assert ctype in self._valid_constraint_types
        cons = self._cons.get(label)
        if cons is not None:
            # Do not re-add, but ensure it's the same type
            assert cons.ctype == ctype
            return
        # LEQ and GEQ constraints share a list: both become rows of G in solve()
        # (GEQ rows are negated there). EQ constraints become rows of A.
        if ctype == CONS_TYPE_LEQ or ctype == CONS_TYPE_GEQ:
            cons = _Constraint(len(self._leq_cons_list), ctype)
            self._cons[label] = cons
            self._leq_cons_list.append(cons)
        elif ctype == CONS_TYPE_EQ:
            cons = _Constraint(len(self._eq_cons_list), ctype)
            self._cons[label] = cons
            self._eq_cons_list.append(cons)
        else:
            assert False, "Unknown constraint type"

    def set_obj_coeff(self, var_label, coeff):
        """Sets a coefficient of a variable in the objective."""
        self._obj_coeffs[var_label] = coeff

    def set_cons_coeff(self, cons_label, var_label, coeff):
        """Sets a coefficient of a constraint in the LP."""
        self._cons[cons_label].coeffs[var_label] = coeff

    def add_to_cons_coeff(self, cons_label, var_label, add_coeff):
        """Adds `add_coeff` to a variable's coefficient in a constraint.

        A coefficient that has not been set yet is treated as 0.
        """
        coeffs = self._cons[cons_label].coeffs
        current = coeffs.get(var_label)
        if current is None:
            current = 0
        coeffs[var_label] = current + add_coeff

    def set_cons_rhs(self, cons_label, value):
        """Sets the right-hand side of a constraint."""
        self._cons[cons_label].rhs = value

    def get_var_id(self, label):
        """Returns the id of a variable, i.e. its index in the solution vector."""
        var = self._vars.get(label)
        assert var is not None
        return var.vid

    def get_num_cons(self):
        """Returns (number of LEQ/GEQ constraints, number of EQ constraints).

        Variable bounds are not counted; they only become rows of G in solve().
        """
        return len(self._leq_cons_list), len(self._eq_cons_list)

    def solve(self, solver=None):
        """Solves the LP.

        Args:
          solver: passed through to `cvxopt.solvers.lp`. None (the default)
            uses cvxopt's built-in solver; cvxopt also accepts 'glpk' and
            'mosek' when those external solvers are installed.

        Returns:
          The cvxopt solution vector `sol["x"]`: one value per variable, where
          the value of the variable labelled L is at index `get_var_id(L)`.
          NOTE(review): the solver status (`sol["status"]`) is not checked
          here; cvxopt may return None for "x" when no solution is found
          (e.g. an infeasible LP) — confirm for the chosen solver.
        """
        # From http://cvxopt.org/userguide/coneprog.html#linear-programming,
        # CVXOPT uses the formulation:
        #    minimize: c^t x
        #       s.t.   Gx <= h
        #              Ax = b
        #
        # Here:
        #  - x is the vector of the variables
        #  - c is the vector of objective coefficients
        #  - G is the matrix of LEQ (and GEQ) constraint coefficients
        #  - h is the vector of right-hand side values of the LEQ/GEQ constraints
        #  - A is the matrix of equality constraint coefficients
        #  - b is the vector of right-hand side values of the equality constraints
        #
        # This function builds these sparse matrices from the information it has
        # gathered, flipping signs where necessary, and adding inequality rows
        # (in G and h) for the upper and lower bounds of variables. It then calls
        # the cvxopt solver and returns the solution vector.
        num_vars = len(self._var_list)
        num_eq_cons = len(self._eq_cons_list)
        # Every finite bound contributes one extra "<=" row to G/h.
        num_bound_rows = sum(
            (var.lb is not None) + (var.ub is not None) for var in self._var_list)
        num_leq_cons = len(self._leq_cons_list) + num_bound_rows
        # Make the matrices (some need to be dense).
        c = cvxopt.matrix([0.0] * num_vars)
        h = cvxopt.matrix([0.0] * num_leq_cons)
        g_mat = cvxopt.spmatrix([], [], [], (num_leq_cons, num_vars))
        a_mat = None
        b = None
        if num_eq_cons > 0:
            a_mat = cvxopt.spmatrix([], [], [], (num_eq_cons, num_vars))
            b = cvxopt.matrix([0.0] * num_eq_cons)
        # Objective coefficients: c
        for var_label, value in self._obj_coeffs.items():
            vid = self._vars[var_label].vid
            # cvxopt minimizes, so a max objective is expressed as min of -c.
            c[vid] = -value if self._objective == OBJ_MAX else value
        # Inequality constraints: G, h
        row = 0
        for cons in self._leq_cons_list:
            # A >= constraint is turned into <= by negating coeffs and the rhs.
            # An unset rhs leaves h[row] at its initial 0.0.
            if cons.rhs is not None:
                h[row] = cons.rhs if cons.ctype == CONS_TYPE_LEQ else -cons.rhs
            for var_label, value in cons.coeffs.items():
                vid = self._vars[var_label].vid
                g_mat[(row, vid)] = value if cons.ctype == CONS_TYPE_LEQ else -value
            row += 1
        # Inequality constraints: variables upper and lower bounds
        for var in self._var_list:
            if var.lb is not None:  # x_i >= lb has to be -x_i <= -lb
                g_mat[(row, var.vid)] = -1.0
                h[row] = -var.lb
                row += 1
            if var.ub is not None:  # x_i <= ub
                g_mat[(row, var.vid)] = 1.0
                h[row] = var.ub
                row += 1
        # Equality constraints: A, b
        if num_eq_cons > 0:
            for eq_row, cons in enumerate(self._eq_cons_list):
                b[eq_row] = cons.rhs if cons.rhs is not None else 0.0
                for var_label, value in cons.coeffs.items():
                    a_mat[(eq_row, self._vars[var_label].vid)] = value
        # Solve! (A and b are only passed when there are equality constraints.)
        if num_eq_cons > 0:
            sol = cvxopt.solvers.lp(c, g_mat, h, a_mat, b, solver=solver)
        else:
            sol = cvxopt.solvers.lp(c, g_mat, h, solver=solver)
        return sol["x"]


def _solve_maximin_lp(num_own_actions, num_opp_actions, utility):
    """Builds and solves the maximin LP for one player of a matrix game.

    Finds a mixed strategy sigma over the player's own actions maximizing V
    subject to: sum_a sigma_a * utility(a, o) >= V for every opponent action o,
    sigma_a >= 0 for all a, and sum_a sigma_a = 1.

    Args:
      num_own_actions: number of pure strategies of the player being solved.
      num_opp_actions: number of pure strategies of the opponent.
      utility: callable (own_action, opp_action) -> payoff of the player being
        solved for that pure-strategy profile.

    Returns:
      The solution vector from `LinearProgram.solve`: the probabilities of the
      player's own actions (in action order) followed by the value V.
    """
    lp = LinearProgram(OBJ_MAX)
    value_label = "V"
    simplex_label = "simplex"
    for a in range(num_own_actions):  # one var per action / pure strategy
        lp.add_or_reuse_variable(a, lb=0)
    # V is created last, so it is the last entry of the solution vector.
    lp.add_or_reuse_variable(value_label)
    lp.set_obj_coeff(value_label, 1.0)  # max V
    for o in range(num_opp_actions):
        # sum_a sigma_a * utility(a, o) - V >= 0
        lp.add_or_reuse_constraint(o, CONS_TYPE_GEQ)
        for a in range(num_own_actions):
            lp.set_cons_coeff(o, a, utility(a, o))
        lp.set_cons_coeff(o, value_label, -1.0)
    # sigma must be a probability distribution: sum_a sigma_a = 1.
    lp.add_or_reuse_constraint(simplex_label, CONS_TYPE_EQ)
    lp.set_cons_rhs(simplex_label, 1.0)
    for a in range(num_own_actions):
        lp.set_cons_coeff(simplex_label, a, 1.0)
    return lp.solve()


def solve_zero_sum_matrix_game(game):
    """Solves a matrix game by using linear programming.

    Note: sets cvxopt.solvers.options["show_progress"] to False (global state).

    Args:
      game: a pyspiel MatrixGame

    Returns:
      A 4-tuple containing:
        - p0_sol (array-like): probability distribution over row actions
          (a slice of the cvxopt solution vector)
        - p1_sol (array-like): probability distribution over column actions
          (a slice of the cvxopt solution vector)
        - p0_sol_value, expected value to the first player
        - p1_sol_value, expected value to the second player
    """

    # Solving a game for player i (e.g. row player) requires finding a mixed
    # policy over player i's pure strategies (actions) such that a value of the
    # mixed strategy against every opponent pure strategy is maximized.
    #
    # For more detail, please refer to Sec 4.1 of Shoham & Leyton-Brown, 2009:
    # Multiagent Systems: Algorithmic, Game-Theoretic, and Logical Foundations
    # http://www.masfoundations.org/mas.pdf
    #
    # For the row player, with mixed strategy sigma over the m rows and col_j
    # the j-th column of its payoff matrix (n columns), the LP looks like:
    #    max V
    #     st. sigma \dot col_0 >= V
    #         sigma \dot col_1 >= V
    #              .
    #              .
    #         sigma \dot col_{n-1} >= V
    #         for all i, sigma_i >= 0
    #         sigma \dot 1 = 1
    # The column player's LP is the same with the roles of rows and columns
    # swapped and player 1's payoffs.
    assert isinstance(game, pyspiel.MatrixGame)
    assert game.get_type().information == pyspiel.GameType.Information.ONE_SHOT
    assert game.get_type().utility == pyspiel.GameType.Utility.ZERO_SUM
    num_rows = game.num_rows()
    num_cols = game.num_cols()
    cvxopt.solvers.options["show_progress"] = False

    # First, do the row player (player 0): own actions are rows.
    sol = _solve_maximin_lp(
        num_rows, num_cols, lambda r, c: game.player_utility(0, r, c))
    p0_sol = sol[:-1]
    p0_sol_val = sol[-1]

    # Now, the column player (player 1): own actions are columns.
    sol = _solve_maximin_lp(
        num_cols, num_rows, lambda c, r: game.player_utility(1, r, c))
    p1_sol = sol[:-1]
    p1_sol_val = sol[-1]

    return p0_sol, p1_sol, p0_sol_val, p1_sol_val